#include "cell.h"
#include "swarch.h"
// cell_workload is an alias for the cell's pe_mask field: the per-cell
// workload estimate is stored in pe_mask while do_lbal() partitions the
// cells, after which do_lbal() zeroes pe_mask and reuses it as a PE bitmask.
#define cell_workload pe_mask
// #define BASE_COST (1734*4)
// Constant added to every cell's in-cutoff pair count by the slave-side
// estimator (see estimate_cell_workload).
// NOTE(review): the origin of the value 6647 is not visible here.
#define BASE_COST (6647)
#ifdef __sw_host__
// #include <athread.h>
#include <qthread.h>
// #define LWPF_UNITS
// List of U(...) entries defined before including lwpf3/lwpf.h.
// NOTE(review): presumably an X-macro naming the profiling units that lwpf
// tracks — confirm against the lwpf header.
#define LWPF_UNITS \
  U(NONB)          \
  U(FIND_BONDS)    \
  U(CELL_IE) U(RIGID) U(LISTED)
#include "lwpf3/lwpf.h"
#include <mpi.h>
#include <cassert>
#include "memory_cpp.hpp"
// Pointer to a single int used as a load-balance counter. swinit() points it
// at platform-specific storage; it is not otherwise read or written in the
// host code visible in this part of the file.
int *lbal_cnt;
//long text1_got[4096] __attribute__((section(".data")));
/*
 * Count the atom combinations between two cells that lie inside the cutoff.
 *
 *   grid  - supplies the cutoff radius (grid->rcut).
 *   icell - first cell; its atoms drive the outer loop.
 *   jcell - second cell; may be the very same cell as icell.
 *
 * Returns the number of (i, j) combinations whose squared separation is
 * strictly below rcut^2. When icell == jcell the raw count is halved because
 * each unordered pair is visited twice; the i == j combinations are part of
 * the raw count too, so the result is an estimate rather than an exact pair
 * count in that case.
 */
INLINE long count_pairs(cellgrid_t *grid, celldata_t *icell, celldata_t *jcell) {
  long ncut = 0;
  /* Offset between the two cells, built from their basis vectors
     (presumably jcell->basis - icell->basis; confirm vecsubv operand order). */
  vec<real> dcell;
  vecsubv(dcell, jcell->basis, icell->basis);

  for (int i = 0; i < icell->natom; i++) {
    for (int j = 0; j < jcell->natom; j++) {
      /* Atom-to-atom displacement: difference of the stored coordinates
         plus the cell-to-cell offset. */
      vec<real> datom;
      vecsubv(datom, jcell->x[j], icell->x[i]);
      vecaddv(datom, datom, dcell);
      real r2 = vecnorm2(datom);
      if (r2 < grid->rcut * grid->rcut) {
        ncut++;
      }
    }
  }
  if (icell == jcell) {
    /* Same cell on both sides: every pair was seen in both orders. */
    ncut /= 2;
  }
  return ncut;
}

/*
 * Feasibility probe for the load balancer.
 *
 * Walks the local cells in FOREACH_LOCAL_CELL order and greedily packs them
 * into consecutive PEs (indices 0..63) so that no PE's summed cell_workload
 * exceeds max_workload.
 *
 * Returns 1 when the cells fit into at most 64 PEs under that cap, and 0
 * when a single cell already exceeds the cap or a 65th PE would be needed.
 */
INLINE int try_decomp(cellgrid_t *grid, long max_workload) {
  int pe = 0;
  long load = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, c) {
    /* A cell heavier than the cap can never be placed. */
    if (c->cell_workload > max_workload) {
      return 0;
    }
    /* Current PE is full: open the next one, failing if none is left. */
    if (load + c->cell_workload > max_workload) {
      if (++pe > 63) {
        return 0;
      }
      load = 0;
    }
    load += c->cell_workload;
  }
  return pe <= 63;
}
/*
 * Binary-search the per-PE workload cap.
 *
 * Searches the range [0, tot_workload] and returns the smallest cap for
 * which try_decomp() reports that the local cells can be packed.
 *
 *   grid         - cell grid whose local cells carry cell_workload values.
 *   tot_workload - upper end of the search range (the caller passes the sum
 *                  of all local cell workloads).
 */
INLINE long bisect_decomp(cellgrid_t *grid, long tot_workload) {
  long lo = 0;
  long hi = tot_workload;
  while (lo != hi) {
    const long probe = (lo + hi) >> 1;
    if (try_decomp(grid, probe)) {
      /* Feasible: the answer is probe or something smaller. */
      hi = probe;
    } else {
      /* Infeasible: the answer must be strictly larger. */
      lo = probe + 1;
    }
  }
  return lo;
}
// Slave-side routine launched by do_lbal() through qthread_spawn().
// NOTE(review): presumably the slave build of estimate_cell_workload() from
// the __sw_slave__ section of this file, with the toolchain's "slave_" prefix
// — confirm.
extern void slave_estimate_cell_workload(cellgrid_t *grid);
/*
 * Recompute the cell-to-PE assignment for the local cells of this rank.
 *
 * Steps:
 *   1. Spawn the slave-side estimator and sum the resulting per-cell
 *      workloads (stored in the field aliased by cell_workload, i.e. pe_mask).
 *   2. Find the smallest per-PE cap with bisect_decomp() and cut the local
 *      cells, in FOREACH_LOCAL_CELL order, into contiguous cell-id ranges
 *      arch_data->pe_range[0..63] = [st, ed). PEs that receive no cells get
 *      an empty range located at the end sentinel.
 *   3. Clear pe_mask on every cell and rebuild it as a bitmask of the PEs
 *      that touch the cell through the half-shell neighbor walk; for each
 *      local cell, rep_init_mask records the neighbor directions in which
 *      this cell is the first one of its PE to reach that neighbor.
 *   4. Give every cell an offset (first_frep) into arch_data->freps, reserving
 *      popcount(pe_mask) - 1 entries per cell (never fewer than 0), and
 *      (re)allocate freps to the total.
 *
 * Rank 0 prints per-PE and summary diagnostics to stdout.
 *
 * Preconditions visible from the code: grid->arch_data points to a
 * sw_archdata_t whose freps is either NULL or a previous allocation.
 * NOTE(review): step 3's range lookup assumes getcid() increases along the
 * FOREACH_LOCAL_CELL iteration order — confirm against the macro.
 */
void do_lbal(cellgrid_t *grid) {
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  sw_archdata_t *arch_data = (sw_archdata_t*)grid->arch_data;

  /* Step 1: estimate workloads on the slave cores, then total them. */
  long tot_workload = 0;
  qthread_spawn(slave_estimate_cell_workload, grid);
  qthread_join();
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    tot_workload += icell->cell_workload;
  }

  /* Step 2: cut the cell sequence into per-PE ranges under the minimal cap. */
  long perpe_workload = bisect_decomp(grid, tot_workload);
  long accum_workload = 0;
  int cur_pe = 0;
  arch_data->pe_range[0].st = 0;
  long max_workload = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    /* Assign a new pe for further workload */
    if (accum_workload + icell->cell_workload > perpe_workload) {
      /* bisect_decomp() guarantees at most 64 PEs (0..63) are needed. */
      assert(cur_pe < 63);
      arch_data->pe_range[cur_pe].ed = getcid(grid, cx, cy, cz);
      if (rank == 0) printf("pe: %d load: %ld cell: %d\n", cur_pe, accum_workload, arch_data->pe_range[cur_pe].ed - arch_data->pe_range[cur_pe].st);
      if (accum_workload > max_workload) {
        max_workload = accum_workload;
      }
      accum_workload = 0;
      cur_pe++;
      arch_data->pe_range[cur_pe].st = getcid(grid, cx, cy, cz);
    }
    accum_workload += icell->cell_workload;
  }
  /* The last PE that received cells never reaches the hand-over branch
     above, so fold its load into the reported maximum here. */
  if (accum_workload > max_workload) {
    max_workload = accum_workload;
  }

  /* End sentinel used to close the last used range and to mark the
     remaining PEs as empty (st == ed). */
  const int end_cid = getcid(grid, grid->nlocal.x, grid->nlocal.y, grid->nlocal.z) + 1;
  arch_data->pe_range[cur_pe].ed = end_cid;
  cur_pe++;
  while (cur_pe <= 63) {
    arch_data->pe_range[cur_pe].st = end_cid;
    arch_data->pe_range[cur_pe].ed = end_cid;
    cur_pe++;
  }

  /* Step 3: the workload values are no longer needed; pe_mask now becomes
     the per-cell bitmask of PEs. */
  FOREACH_CELL(grid, cx, cy, cz, cell) {
    cell->pe_mask = 0;
  }

  cur_pe = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    int cellid = getcid(grid, cx, cy, cz);
    icell->rep_init_mask = 0;
    /* Advance to the PE whose range contains this cell id. */
    while (arch_data->pe_range[cur_pe].ed <= cellid)
      cur_pe++;
    FOREACH_NEIGHBOR_HALF_SHELL(grid, cx, cy, cz, dx, dy, dz, jcell) {
      /* Unsigned shifts: cur_pe and did may reach 63, which would shift a
         signed 1L into the sign bit. The resulting bit pattern is the same. */
      if ((jcell->pe_mask & (1UL << cur_pe)) == 0) {
        /* First time this PE reaches jcell: remember the direction on the
           reaching cell and mark the PE on the neighbor. */
        int did = hdcell(grid->nn, dx, dy, dz);
        assert(did < 64);
        icell->rep_init_mask |= 1UL << did;
        jcell->pe_mask |= 1UL << cur_pe;
      }
    }
  }

  /* Step 4: lay out the replica slots, one fewer than the number of PEs
     touching each cell. */
  int nreps = 0, ncells = 0;
  FOREACH_CELL(grid, cx, cy, cz, cell) {
    cell->first_frep = nreps;
    nreps += max(__builtin_popcountl(cell->pe_mask) - 1, 0);
    ncells++;
  }
  if (arch_data->freps) {
    arch_data->freps = esmd::reallocate(arch_data->freps, nreps);
  } else {
    arch_data->freps = esmd::allocate<frep_t>(nreps, "sw/force replicas");
  }
  arch_data->nreps = nreps;
  if (rank == 0){
    /* Heaviest PE load versus the ideal even split over 64 PEs. */
    printf("%ld %ld\n", max_workload, tot_workload / 64);
    printf("nreps=%d, ncells=%d\n", nreps, ncells);
  }
}
/*
 * One-time host-side setup of the Sunway-specific state attached to a grid.
 *
 * Allocates the sw_archdata_t, attaches it to grid->arch_data, runs an
 * initial load balance, initializes lwpf, allocates the pack parameter
 * array, and points lbal_cnt at platform-specific storage.
 *
 *   grid - cell grid to initialize; grid->arch_data is overwritten.
 */
void swinit(cellgrid_t *grid) {
  sw_archdata_t *arch_data = esmd::allocate<sw_archdata_t>(1, "sw/archdata");
  // arch_data->last_lbal = -2000;
  /* do_lbal() reads grid->arch_data and checks freps for NULL, so both must
     be in place before the call. */
  grid->arch_data = arch_data;
  arch_data->freps = NULL;
  arch_data->last_lbal = -1;
  do_lbal(grid);

  // qthread_init();
  /* Event configuration for lwpf. It is only consumed by the commented-out
     lwpf_init(&conf) below; the active call passes NULL. */
  evt_conf_t conf;
#ifdef __sw7__
  conf.pc_mask = 0x11;
  conf.evt[0] = PC0_CYCLE;
  conf.evt[4] = PC4_REQ_RSCATTER_MEMACC;
#elif defined(__sw5__)
  conf.pc_mask = 0xc;
  conf.evt[2] = PC2_CNT_GLD;
  conf.evt[3] = PC3_CYCLE;
#elif defined(__sw9__)
  conf.pc_mask = 0x11;
  conf.evt[0] = PC0_CYCLE;
  conf.evt[4] = PC4_REQ_RSCATTER_MEMACC;
#endif
  (void)conf; /* keep the toggle above without an unused-variable warning */
  // lwpf_init(&conf);
  lwpf_init(NULL);

  arch_data->pack_params = esmd::allocate<vec_pack_param_t>(N_PACK_DIRS, "sw/pack params");
  // esmd_icalloc(arch_data->flocks, vecvol(grid->nall), "sw/force locks");

  /* Explicit casts: C++ has no implicit integer-to-pointer or void*-to-int*
     conversion. */
#ifdef __sw7__
  /* NOTE(review): hard-coded address; presumably an uncached location
     standing in for libc_uncached_malloc(4) — confirm. */
  lbal_cnt = (int *)0x700000023e30L;//libc_uncached_malloc(4);
#elif defined(__sw9__)
  lbal_cnt = (int *)__libc_malloc_uncached(4);
#elif defined(__sw5__)
  lbal_cnt = esmd::allocate<int>(1, "load balance counter");
#endif
}
/*
 * Write the lwpf summary at shutdown.
 *
 * Every rank writes its summary to "lwpf.<rank, zero-padded to 6>.txt" in
 * the current directory; rank 0 additionally prints its summary to stdout.
 * If the per-rank file cannot be opened, the failure is reported on stderr
 * and the file report is skipped instead of passing a NULL stream on.
 */
void swfinal() {
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  char fn[256];
  snprintf(fn, sizeof(fn), "lwpf.%06d.txt", rank);
  FILE *out = fopen(fn, "w");
  if (out) {
    lwpf_report_summary(out);
    fclose(out);
  } else {
    /* Prints the file name followed by the system error text. */
    perror(fn);
  }
  if (rank == 0) lwpf_report_summary(stdout);
}
#endif
#ifdef __sw_slave__
#include "dma_macros_new.h"
#include <qthread_slave.h>
#include <stdio.h>
/*
 * Slave-side workload estimator.
 *
 * For every local cell visited by FOREACH_LOCAL_CELL_CPE_RR, counts the atom
 * combinations within the cutoff between the cell and each of its half-shell
 * neighbors, and stores (count + BASE_COST) into the cell's cell_workload.
 *
 *   grid - grid descriptor; it is first copied into a local cellgrid_t with
 *          pe_getn and the copy is used for all further reads.
 *
 * NOTE(review): pe_getn/dma_syn come from dma_macros_new.h; they are assumed
 * to start a transfer into the local buffer and wait for it — confirm.
 * NOTE(review): the _CPE_RR macro presumably splits the local cells among
 * the slave cores round-robin — confirm.
 */
void estimate_cell_workload(cellgrid_t *grid) {
  dma_init();
  cellgrid_t lgrid;
  pe_getn(grid, &lgrid, 1);
  dma_syn();
  real rcut2 = lgrid.rcut * lgrid.rcut;
  FOREACH_LOCAL_CELL_CPE_RR(&lgrid, cx, cy, cz, icell) {
    /* Fetch the cell's metadata (read starting at its basis field), then as
       many coordinates as the metadata says the cell holds. */
    cellmeta_t imeta;
    pe_getn(&icell->basis, &imeta, 1);
    dma_syn();
    vec<real> xi[CELL_CAP];
    pe_getn(icell->x, xi, imeta.natom);
    dma_syn();

    /* long, to match the long per-neighbor counts accumulated into it. */
    long pair_workload = 0;
    FOREACH_NEIGHBOR_HALF_SHELL(&lgrid, cx, cy, cz, dx, dy, dz, jcell) {
      cellmeta_t jmeta;
      pe_getn(&jcell->basis, &jmeta, 1);
      dma_syn();
      vec<real> xj[CELL_CAP];
      pe_getn(jcell->x, xj, jmeta.natom);
      dma_syn();
      long ncut = 0;
      /* Cell-to-cell offset built from the two basis vectors. */
      vec<real> dcell;
      vecsubv(dcell, jmeta.basis, imeta.basis);

      for (int i = 0; i < imeta.natom; i++) {
        for (int j = 0; j < jmeta.natom; j++) {
          vec<real> datom;
          vecsubv(datom, xj[j], xi[i]);
          vecaddv(datom, datom, dcell);
          real r2 = vecnorm2(datom);
          if (r2 < rcut2) {
            ncut++;
          }
        }
      }
      if (icell == jcell) {
        /* Cell paired with itself: each pair was visited in both orders. */
        ncut /= 2;
      }

      pair_workload += ncut;
    }
    /* Written straight through the cell pointer; the host sums these values
       in do_lbal(). */
    icell->cell_workload = pair_workload + BASE_COST;
    // tot_workload += icell->cell_workload;
  }
}

#endif
